#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System.Net.Sockets;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.Extensions.DependencyInjection;

namespace Grpc.AspNetCore.FunctionalTests.Infrastructure;

/// <summary>
/// Test fixture that hosts <typeparamref name="TStartup"/> in an <see cref="InProcessTestServer"/> and
/// exposes <see cref="HttpClient"/> instances configured to call the server's endpoints.
/// </summary>
/// <typeparam name="TStartup">Startup class used to configure the hosted application.</typeparam>
public class GrpcTestFixture<TStartup> : IDisposable where TStartup : class
{
    // Path used for the Unix domain socket endpoint (net5.0+ default configuration only).
    private readonly string _socketPath = Path.Combine(Path.GetTempPath(), "grpc-transporter.tmp");
    private readonly InProcessTestServer _server;

    /// <summary>
    /// Creates and starts the test server, then creates the default <see cref="Client"/>.
    /// </summary>
    /// <param name="initialConfigureServices">
    /// Optional service registrations, run before the fixture registers <see cref="DynamicGrpcServiceRegistry"/>.
    /// </param>
    /// <param name="configureKestrel">
    /// Optional Kestrel configuration. When supplied it replaces the default endpoints entirely; like the
    /// default configuration, it should add the URL of each endpoint it creates to the supplied dictionary.
    /// </param>
    /// <param name="defaultClientEndpointName">
    /// Endpoint targeted by <see cref="Client"/>. When null, HTTP/2 is used (HTTP/3 with TLS when compiled
    /// with HTTP3_TESTING).
    /// </param>
    public GrpcTestFixture(
        Action<IServiceCollection>? initialConfigureServices = null,
        Action<KestrelServerOptions, IDictionary<TestServerEndpointName, string>>? configureKestrel = null,
        TestServerEndpointName? defaultClientEndpointName = null)
    {
        Action<IServiceCollection> configureServices = services =>
        {
            // Registers a service for tests to add new methods
            services.AddSingleton<DynamicGrpcServiceRegistry>();
        };

        _server = new InProcessTestServer<TStartup>(
            services =>
            {
                initialConfigureServices?.Invoke(services);
                configureServices(services);
            },
            (options, urls) =>
            {
                if (configureKestrel != null)
                {
                    configureKestrel(options, urls);
                    return;
                }

                ConfigureDefaultEndpoints(options, urls);
            });

        _server.StartServer();

        try
        {
            DynamicGrpc = _server.Host!.Services.GetRequiredService<DynamicGrpcServiceRegistry>();

#if !NET5_0
            // Allow SocketsHttpHandler to use HTTP/2 without TLS for the plaintext endpoints.
            AppContext.SetSwitch("System.Net.Http.SocketsHttpHandler.Http2UnencryptedSupport", true);
#endif

            (Client, Handler) = CreateHttpCore(defaultClientEndpointName);
        }
        catch
        {
            // The fixture is never handed to the caller on this path, so Dispose will not run.
            // Stop the already-started server so its fixed ports/socket are not left bound for later fixtures.
            _server.Dispose();
            throw;
        }
    }

    /// <summary>Registry resolved from the server's services; lets tests add gRPC methods at runtime.</summary>
    public DynamicGrpcServiceRegistry DynamicGrpc { get; }

    /// <summary>Handler chain used by <see cref="Client"/>.</summary>
    public HttpMessageHandler Handler { get; }

    /// <summary>Client targeting the endpoint selected at construction.</summary>
    public HttpClient Client { get; }

    /// <summary>
    /// Creates a new <see cref="HttpClient"/> whose base address points at <paramref name="endpointName"/>.
    /// </summary>
    /// <param name="endpointName">Target endpoint; null selects the default endpoint.</param>
    /// <param name="messageHandler">Optional handler placed in front of the underlying <see cref="SocketsHttpHandler"/>.</param>
    /// <param name="configureHandler">Optional callback to customize the underlying <see cref="SocketsHttpHandler"/>.</param>
    public HttpClient CreateClient(TestServerEndpointName? endpointName = null, DelegatingHandler? messageHandler = null, Action<SocketsHttpHandler>? configureHandler = null)
    {
        return CreateHttpCore(endpointName, messageHandler, configureHandler).client;
    }

    /// <summary>
    /// Creates a handler chain for <paramref name="endpointName"/> and returns it with the endpoint's base address.
    /// </summary>
    /// <param name="endpointName">Target endpoint; null selects the default endpoint.</param>
    /// <param name="messageHandler">Optional handler placed in front of the underlying <see cref="SocketsHttpHandler"/>.</param>
    /// <param name="configureHandler">Optional callback to customize the underlying <see cref="SocketsHttpHandler"/>.</param>
    public (HttpMessageHandler handler, Uri address) CreateHandler(TestServerEndpointName? endpointName = null, DelegatingHandler? messageHandler = null, Action<SocketsHttpHandler>? configureHandler = null)
    {
        // The intermediate client is deliberately not disposed: disposing it would also dispose the returned handler.
        var (client, handler) = CreateHttpCore(endpointName, messageHandler, configureHandler);
        return (handler, client.BaseAddress!);
    }

    /// <summary>
    /// Builds a <see cref="SocketsHttpHandler"/>-based handler chain and an <see cref="HttpClient"/> for the endpoint.
    /// </summary>
    private (HttpClient client, HttpMessageHandler handler) CreateHttpCore(TestServerEndpointName? endpointName = null, DelegatingHandler? messageHandler = null, Action<SocketsHttpHandler>? configureHandler = null)
    {
#if HTTP3_TESTING
        endpointName ??= TestServerEndpointName.Http3WithTls;
#else
        endpointName ??= TestServerEndpointName.Http2;
#endif

        var socketsHttpHandler = new SocketsHttpHandler();
        socketsHttpHandler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions
        {
            // Accept any server certificate so the test certificate (server1.pfx) is never rejected.
            RemoteCertificateValidationCallback = (_, __, ___, ____) => true
        };

        // Runs before the Unix socket ConnectCallback below is assigned, so that assignment takes precedence.
        configureHandler?.Invoke(socketsHttpHandler);

#if NET5_0_OR_GREATER
        if (endpointName == TestServerEndpointName.UnixDomainSocket)
        {
            // Connect over the Unix domain socket instead of TCP.
            var udsEndPoint = new UnixDomainSocketEndPoint(_server.GetUrl(endpointName.Value));
            var connectionFactory = new UnixDomainSocketConnectionFactory(udsEndPoint);

            socketsHttpHandler.ConnectCallback = connectionFactory.ConnectAsync;
        }
#endif

        HttpMessageHandler handler = socketsHttpHandler;
        if (messageHandler != null)
        {
            messageHandler.InnerHandler = socketsHttpHandler;
            handler = messageHandler;
        }

#if NET6_0_OR_GREATER
        if (endpointName == TestServerEndpointName.Http3WithTls)
        {
            // TODO(JamesNK): There is a bug with SocketsHttpHandler and HTTP/3 that prevents calls
            // upgrading from 2 to 3. Force HTTP/3 calls to require that protocol.
            handler = new Http3DelegatingHandler(handler);
        }
#endif

        var client = new HttpClient(handler);

        if (endpointName == TestServerEndpointName.Http2)
        {
            client.DefaultRequestVersion = new Version(2, 0);
#if NET5_0_OR_GREATER
            client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
#endif
        }

        client.BaseAddress = CalculateBaseAddress(endpointName.Value);

        return (client, handler);
    }

    /// <summary>
    /// Returns the base address for a client. Unix socket clients get a placeholder "http://localhost" address
    /// because the connection itself is made through <see cref="SocketsHttpHandler.ConnectCallback"/>.
    /// </summary>
    private Uri CalculateBaseAddress(TestServerEndpointName endpointName)
    {
#if NET5_0_OR_GREATER
        if (endpointName == TestServerEndpointName.UnixDomainSocket)
        {
            return new Uri("http://localhost");
        }
#endif

        return new Uri(_server.GetUrl(endpointName));
    }

    /// <summary>
    /// Gets the URL clients should use for <paramref name="endpointName"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The endpoint name is not one this fixture knows about.</exception>
    public Uri GetUrl(TestServerEndpointName endpointName)
    {
        switch (endpointName)
        {
            case TestServerEndpointName.Http1:
            case TestServerEndpointName.Http2:
            case TestServerEndpointName.Http1WithTls:
            case TestServerEndpointName.Http2WithTls:
#if NET6_0_OR_GREATER
            case TestServerEndpointName.Http3WithTls:
#endif
                return new Uri(_server.GetUrl(endpointName));
#if NET5_0_OR_GREATER
            case TestServerEndpointName.UnixDomainSocket:
                return new Uri("http://localhost");
#endif
            default:
                throw new ArgumentException("Unexpected value: " + endpointName, nameof(endpointName));
        }
    }

    /// <summary>Forwards subscriptions to the underlying server's <c>ServerLogged</c> event.</summary>
    internal event Action<LogRecord> ServerLogged
    {
        add => _server.ServerLogged += value;
        remove => _server.ServerLogged -= value;
    }

    /// <summary>Disposes the default client (and with it <see cref="Handler"/>) and the test server.</summary>
    public void Dispose()
    {
        Client.Dispose();
        _server.Dispose();
    }

    /// <summary>
    /// Adds the fixture's default endpoints to <paramref name="options"/> and records each endpoint's URL in <paramref name="urls"/>.
    /// </summary>
    private void ConfigureDefaultEndpoints(KestrelServerOptions options, IDictionary<TestServerEndpointName, string> urls)
    {
        urls[TestServerEndpointName.Http2] = "http://127.0.0.1:50050";
        options.ListenLocalhost(50050, listenOptions =>
        {
            listenOptions.Protocols = HttpProtocols.Http2;
        });

        urls[TestServerEndpointName.Http1] = "http://127.0.0.1:50040";
        options.ListenLocalhost(50040, listenOptions =>
        {
            listenOptions.Protocols = HttpProtocols.Http1;
        });

        urls[TestServerEndpointName.Http2WithTls] = "https://127.0.0.1:50030";
        options.ListenLocalhost(50030, listenOptions =>
        {
            listenOptions.Protocols = HttpProtocols.Http2;
            UseTestCertificate(listenOptions);
        });

        urls[TestServerEndpointName.Http1WithTls] = "https://127.0.0.1:50020";
        options.ListenLocalhost(50020, listenOptions =>
        {
            listenOptions.Protocols = HttpProtocols.Http1;
            UseTestCertificate(listenOptions);
        });

#if NET5_0_OR_GREATER
        // Remove any file left at the socket path (e.g. from a previous run) before listening on it.
        if (File.Exists(_socketPath))
        {
            File.Delete(_socketPath);
        }

        urls[TestServerEndpointName.UnixDomainSocket] = _socketPath;
        options.ListenUnixSocket(_socketPath, listenOptions =>
        {
            listenOptions.Protocols = HttpProtocols.Http2;
        });
#endif

#if NET6_0_OR_GREATER
        if (RequireHttp3Attribute.IsSupported(out _))
        {
            urls[TestServerEndpointName.Http3WithTls] = "https://127.0.0.1:55019";
            options.ListenLocalhost(55019, listenOptions =>
            {
#pragma warning disable CA2252 // This API requires opting into preview features
                // Support HTTP/2 for connectivity health in load balancing to work.
                listenOptions.Protocols = HttpProtocols.Http2 | HttpProtocols.Http3;
#pragma warning restore CA2252 // This API requires opting into preview features

                UseTestCertificate(listenOptions);
            });
        }
#endif
    }

    /// <summary>
    /// Enables HTTPS on <paramref name="listenOptions"/> using the server1.pfx certificate that sits next to
    /// the assembly containing <see cref="InProcessTestServer"/>.
    /// </summary>
    private static void UseTestCertificate(ListenOptions listenOptions)
    {
        var basePath = Path.GetDirectoryName(typeof(InProcessTestServer).Assembly.Location);
        var certPath = Path.Combine(basePath!, "server1.pfx");
        listenOptions.UseHttps(certPath, "1111");
    }

#if NET6_0_OR_GREATER
    /// <summary>
    /// Forces every request sent through it to use exactly HTTP/3.
    /// </summary>
    private class Http3DelegatingHandler : DelegatingHandler
    {
        public Http3DelegatingHandler(HttpMessageHandler innerHandler)
        {
            InnerHandler = innerHandler;
        }

        protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            request.Version = new Version(3, 0);
            request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
            return base.SendAsync(request, cancellationToken);
        }
    }
#endif
}
